# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/services/observability_service.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

Observability Service Implementation.
This module provides OpenTelemetry-style observability for MCP Gateway,
capturing traces, spans, events, and metrics for all operations.

It includes:
- Trace creation and management
- Span tracking with hierarchical nesting
- Event logging within spans
- Metrics collection and storage
- Query and filtering capabilities
- Integration with FastAPI middleware

Examples:
    >>> from mcpgateway.services.observability_service import ObservabilityService  # doctest: +SKIP
    >>> service = ObservabilityService()  # doctest: +SKIP
    >>> trace_id = service.start_trace(db, "GET /tools", http_method="GET", http_url="/tools")  # doctest: +SKIP
    >>> span_id = service.start_span(db, trace_id, "database_query", resource_type="database")  # doctest: +SKIP
    >>> service.end_span(db, span_id, status="ok")  # doctest: +SKIP
    >>> service.end_trace(db, trace_id, status="ok", http_status_code=200)  # doctest: +SKIP
"""

# Standard
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timezone
import logging
import re
import traceback
from typing import Any, Dict, List, Optional, Tuple
import uuid

# Third-Party
from sqlalchemy import desc
from sqlalchemy.orm import joinedload, Session

# First-Party
from mcpgateway.db import ObservabilityEvent, ObservabilityMetric, ObservabilitySpan, ObservabilityTrace

logger = logging.getLogger(__name__)

# Context variable for tracking the current trace_id across async calls
current_trace_id: ContextVar[Optional[str]] = ContextVar("current_trace_id", default=None)


def utc_now() -> datetime:
    """Return current UTC time with timezone.

    Returns:
        datetime: Current time in UTC with timezone info
    """
    return datetime.now(timezone.utc)


def ensure_timezone_aware(dt: datetime) -> datetime:
    """Ensure datetime is timezone-aware (UTC).

    SQLite returns naive datetimes even when stored with timezone info.
    This helper ensures consistency for datetime arithmetic.

    Args:
        dt: Datetime that may be naive or aware

    Returns:
        Timezone-aware datetime in UTC
    """
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt


def parse_traceparent(traceparent: str) -> Optional[Tuple[str, str, str]]:
    """Parse W3C Trace Context traceparent header.

    Format: version-trace_id-parent_id-trace_flags
    Example: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01

    Args:
        traceparent: W3C traceparent header value

    Returns:
        Tuple of (trace_id, parent_span_id, trace_flags) or None if invalid

    Examples:
        >>> parse_traceparent("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")  # doctest: +SKIP
        ('0af7651916cd43dd8448eb211c80319c', 'b7ad6b7169203331', '01')
    """
    # W3C Trace Context format: 00-trace_id(32hex)-parent_id(16hex)-flags(2hex)
    pattern = r"^([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$"
    match = re.match(pattern, traceparent.lower())

    if not match:
        logger.warning(f"Invalid traceparent format: {traceparent}")
        return None

    version, trace_id, parent_id, flags = match.groups()

    # Only support version 00 for now
    if version != "00":
        logger.warning(f"Unsupported traceparent version: {version}")
        return None

    # Validate trace_id and parent_id are not all zeros
    if trace_id == "0" * 32 or parent_id == "0" * 16:
        logger.warning("Invalid traceparent with zero trace_id or parent_id")
        return None

    return (trace_id, parent_id, flags)


def generate_w3c_trace_id() -> str:
    """Generate a W3C compliant trace ID (32 hex characters).

    Returns:
        32-character lowercase hex string

    Examples:
        >>> trace_id = generate_w3c_trace_id()  # doctest: +SKIP
        >>> len(trace_id)  # doctest: +SKIP
        32
    """
    return uuid.uuid4().hex + uuid.uuid4().hex[:16]


def generate_w3c_span_id() -> str:
    """Generate a W3C compliant span ID (16 hex characters).

    Returns:
        16-character lowercase hex string

    Examples:
        >>> span_id = generate_w3c_span_id()  # doctest: +SKIP
        >>> len(span_id)  # doctest: +SKIP
        16
    """
    return uuid.uuid4().hex[:16]


def format_traceparent(trace_id: str, span_id: str, sampled: bool = True) -> str:
    """Format a W3C traceparent header value.

    Args:
        trace_id: 32-character hex trace ID
        span_id: 16-character hex span ID
        sampled: Whether the trace is sampled (affects trace-flags)

    Returns:
        W3C traceparent header value

    Examples:
        >>> format_traceparent("0af7651916cd43dd8448eb211c80319c", "b7ad6b7169203331")  # doctest: +SKIP
        '00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01'
    """
    flags = "01" if sampled else "00"
    return f"00-{trace_id}-{span_id}-{flags}"


class ObservabilityService:
    """Service for managing observability traces, spans, events, and metrics.

    This service provides comprehensive observability capabilities similar to
    OpenTelemetry, allowing tracking of request flows through the system.

    Examples:
        >>> service = ObservabilityService()  # doctest: +SKIP
        >>> trace_id = service.start_trace(db, "POST /tools/invoke")  # doctest: +SKIP
        >>> span_id = service.start_span(db, trace_id, "tool_execution")  # doctest: +SKIP
        >>> service.end_span(db, span_id, status="ok")  # doctest: +SKIP
        >>> service.end_trace(db, trace_id, status="ok")  # doctest: +SKIP
    """

    # ==============================
    # Trace Management
    # ==============================

    def start_trace(
        self,
        db: Session,
        name: str,
        trace_id: Optional[str] = None,
        parent_span_id: Optional[str] = None,
        http_method: Optional[str] = None,
        http_url: Optional[str] = None,
        user_email: Optional[str] = None,
        user_agent: Optional[str] = None,
        ip_address: Optional[str] = None,
        attributes: Optional[Dict[str, Any]] = None,
        resource_attributes: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Start a new trace.

        Args:
            db: Database session
            name: Trace name (e.g., "POST /tools/invoke")
            trace_id: External trace ID (for distributed tracing, W3C format)
            parent_span_id: Parent span ID from upstream service
            http_method: HTTP method (GET, POST, etc.)
            http_url: Full request URL
            user_email: Authenticated user email
            user_agent: Client user agent string
            ip_address: Client IP address
            attributes: Additional trace attributes
            resource_attributes: Resource attributes (service name, version, etc.)

        Returns:
            Trace ID (UUID string or W3C format)

        Examples:
            >>> trace_id = service.start_trace(  # doctest: +SKIP
            ...     db,
            ...     "POST /tools/invoke",
            ...     http_method="POST",
            ...     http_url="https://api.example.com/tools/invoke",
            ...     user_email="user@example.com"
            ... )
        """
        # Use provided trace_id or generate new UUID
        if not trace_id:
            trace_id = str(uuid.uuid4())

        # Add parent context to attributes if provided
        attrs = attributes or {}
        if parent_span_id:
            attrs["parent_span_id"] = parent_span_id

        trace = ObservabilityTrace(
            trace_id=trace_id,
            name=name,
            start_time=utc_now(),
            status="unset",
            http_method=http_method,
            http_url=http_url,
            user_email=user_email,
            user_agent=user_agent,
            ip_address=ip_address,
            attributes=attrs,
            resource_attributes=resource_attributes or {},
            created_at=utc_now(),
        )
        db.add(trace)
        db.commit()
        logger.debug(f"Started trace {trace_id}: {name}")
        return trace_id

    def end_trace(
        self,
        db: Session,
        trace_id: str,
        status: str = "ok",
        status_message: Optional[str] = None,
        http_status_code: Optional[int] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        """End a trace.

        Args:
            db: Database session
            trace_id: Trace ID to end
            status: Trace status (ok, error)
            status_message: Optional status message
            http_status_code: HTTP response status code
            attributes: Additional attributes to merge

        Examples:
            >>> service.end_trace(  # doctest: +SKIP
            ...     db,
            ...     trace_id,
            ...     status="ok",
            ...     http_status_code=200
            ... )
        """
        trace = db.query(ObservabilityTrace).filter_by(trace_id=trace_id).first()
        if not trace:
            logger.warning(f"Trace {trace_id} not found")
            return

        end_time = utc_now()
        duration_ms = (end_time - ensure_timezone_aware(trace.start_time)).total_seconds() * 1000

        trace.end_time = end_time
        trace.duration_ms = duration_ms
        trace.status = status
        trace.status_message = status_message
        if http_status_code is not None:
            trace.http_status_code = http_status_code
        if attributes:
            trace.attributes = {**(trace.attributes or {}), **attributes}

        db.commit()
        logger.debug(f"Ended trace {trace_id}: {status} ({duration_ms:.2f}ms)")

    def get_trace(self, db: Session, trace_id: str, include_spans: bool = False) -> Optional[ObservabilityTrace]:
        """Get a trace by ID.

        Args:
            db: Database session
            trace_id: Trace ID
            include_spans: Whether to load spans eagerly

        Returns:
            Trace object or None if not found

        Examples:
            >>> trace = service.get_trace(db, trace_id, include_spans=True)  # doctest: +SKIP
            >>> if trace:  # doctest: +SKIP
            ...     print(f"Trace: {trace.name}, Spans: {len(trace.spans)}")  # doctest: +SKIP
        """
        query = db.query(ObservabilityTrace).filter_by(trace_id=trace_id)
        if include_spans:
            query = query.options(joinedload(ObservabilityTrace.spans))
        return query.first()

    # ==============================
    # Span Management
    # ==============================

    def start_span(
        self,
        db: Session,
        trace_id: str,
        name: str,
        parent_span_id: Optional[str] = None,
        kind: str = "internal",
        resource_name: Optional[str] = None,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Start a new span within a trace.

        Args:
            db: Database session
            trace_id: Parent trace ID
            name: Span name (e.g., "database_query", "tool_invocation")
            parent_span_id: Parent span ID (for nested spans)
            kind: Span kind (internal, server, client, producer, consumer)
            resource_name: Resource name being operated on
            resource_type: Resource type (tool, resource, prompt, etc.)
            resource_id: Resource ID
            attributes: Additional span attributes

        Returns:
            Span ID (UUID string)

        Examples:
            >>> span_id = service.start_span(  # doctest: +SKIP
            ...     db,
            ...     trace_id,
            ...     "tool_invocation",
            ...     resource_type="tool",
            ...     resource_name="get_weather"
            ... )
        """
        span_id = str(uuid.uuid4())
        span = ObservabilitySpan(
            span_id=span_id,
            trace_id=trace_id,
            parent_span_id=parent_span_id,
            name=name,
            kind=kind,
            start_time=utc_now(),
            status="unset",
            resource_name=resource_name,
            resource_type=resource_type,
            resource_id=resource_id,
            attributes=attributes or {},
            created_at=utc_now(),
        )
        db.add(span)
        db.commit()
        logger.debug(f"Started span {span_id}: {name} (trace={trace_id})")
        return span_id

    def end_span(
        self,
        db: Session,
        span_id: str,
        status: str = "ok",
        status_message: Optional[str] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        """End a span.

        Args:
            db: Database session
            span_id: Span ID to end
            status: Span status (ok, error)
            status_message: Optional status message
            attributes: Additional attributes to merge

        Examples:
            >>> service.end_span(db, span_id, status="ok")  # doctest: +SKIP
        """
        span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
        if not span:
            logger.warning(f"Span {span_id} not found")
            return

        end_time = utc_now()
        duration_ms = (end_time - ensure_timezone_aware(span.start_time)).total_seconds() * 1000

        span.end_time = end_time
        span.duration_ms = duration_ms
        span.status = status
        span.status_message = status_message
        if attributes:
            span.attributes = {**(span.attributes or {}), **attributes}

        db.commit()
        logger.debug(f"Ended span {span_id}: {status} ({duration_ms:.2f}ms)")

    @contextmanager
    def trace_span(
        self,
        db: Session,
        trace_id: str,
        name: str,
        parent_span_id: Optional[str] = None,
        resource_type: Optional[str] = None,
        resource_name: Optional[str] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ):
        """Context manager for automatic span lifecycle management.

        Args:
            db: Database session
            trace_id: Parent trace ID
            name: Span name
            parent_span_id: Parent span ID (optional)
            resource_type: Resource type
            resource_name: Resource name
            attributes: Additional attributes

        Yields:
            Span ID

        Raises:
            Exception: Re-raises any exception after logging it in the span

        Examples:
            >>> with service.trace_span(db, trace_id, "database_query") as span_id:  # doctest: +SKIP
            ...     results = db.query(Tool).all()  # doctest: +SKIP
        """
        span_id = self.start_span(db, trace_id, name, parent_span_id, resource_type=resource_type, resource_name=resource_name, attributes=attributes)
        try:
            yield span_id
            self.end_span(db, span_id, status="ok")
        except Exception as e:
            self.end_span(db, span_id, status="error", status_message=str(e))
            self.add_event(db, span_id, "exception", severity="error", message=str(e), exception_type=type(e).__name__, exception_message=str(e), exception_stacktrace=traceback.format_exc())
            raise

    @contextmanager
    def trace_tool_invocation(
        self,
        db: Session,
        tool_name: str,
        arguments: Dict[str, Any],
        integration_type: Optional[str] = None,
    ):
        """Context manager for tracing MCP tool invocations.

        This automatically creates a span for tool execution, capturing timing,
        arguments, results, and errors.

        Args:
            db: Database session
            tool_name: Name of the tool being invoked
            arguments: Tool arguments (will be sanitized)
            integration_type: Integration type (MCP, REST, A2A, etc.)

        Yields:
            Tuple of (span_id, result_dict) - update result_dict with tool results

        Raises:
            Exception: Re-raises any exception from tool invocation after logging

        Examples:
            >>> with service.trace_tool_invocation(db, "weather", {"city": "NYC"}) as (span_id, result):  # doctest: +SKIP
            ...     response = await http_client.post(...)  # doctest: +SKIP
            ...     result["status_code"] = response.status_code  # doctest: +SKIP
            ...     result["response_size"] = len(response.content)  # doctest: +SKIP
        """
        trace_id = current_trace_id.get()
        if not trace_id:
            # No active trace, yield a no-op
            result_dict: Dict[str, Any] = {}
            yield (None, result_dict)
            return

        # Sanitize arguments (remove sensitive data)
        safe_args = {k: ("***REDACTED***" if any(sensitive in k.lower() for sensitive in ["password", "token", "key", "secret"]) else v) for k, v in arguments.items()}

        # Start tool invocation span
        span_id = self.start_span(
            db=db,
            trace_id=trace_id,
            name=f"tool.invoke.{tool_name}",
            kind="client",
            resource_type="tool",
            resource_name=tool_name,
            attributes={
                "tool.name": tool_name,
                "tool.integration_type": integration_type,
                "tool.argument_count": len(arguments),
                "tool.arguments": safe_args,
            },
        )

        result_dict = {}
        try:
            yield (span_id, result_dict)

            # End span with results
            self.end_span(
                db=db,
                span_id=span_id,
                status="ok",
                attributes={
                    "tool.result": result_dict,
                },
            )
        except Exception as e:
            # Log error in span
            self.end_span(db=db, span_id=span_id, status="error", status_message=str(e))

            self.add_event(
                db=db,
                span_id=span_id,
                name="tool.error",
                severity="error",
                message=str(e),
                exception_type=type(e).__name__,
                exception_message=str(e),
                exception_stacktrace=traceback.format_exc(),
            )
            raise

    # ==============================
    # Event Management
    # ==============================

    def add_event(
        self,
        db: Session,
        span_id: str,
        name: str,
        severity: Optional[str] = None,
        message: Optional[str] = None,
        exception_type: Optional[str] = None,
        exception_message: Optional[str] = None,
        exception_stacktrace: Optional[str] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Add an event to a span.

        Args:
            db: Database session
            span_id: Parent span ID
            name: Event name
            severity: Log severity (debug, info, warning, error, critical)
            message: Event message
            exception_type: Exception class name
            exception_message: Exception message
            exception_stacktrace: Exception stacktrace
            attributes: Additional event attributes

        Returns:
            Event ID

        Examples:
            >>> event_id = service.add_event(  # doctest: +SKIP
            ...     db,  # doctest: +SKIP
            ...     span_id,  # doctest: +SKIP
            ...     "database_connection_error",  # doctest: +SKIP
            ...     severity="error",  # doctest: +SKIP
            ...     message="Failed to connect to database"  # doctest: +SKIP
            ... )  # doctest: +SKIP
        """
        event = ObservabilityEvent(
            span_id=span_id,
            name=name,
            timestamp=utc_now(),
            severity=severity,
            message=message,
            exception_type=exception_type,
            exception_message=exception_message,
            exception_stacktrace=exception_stacktrace,
            attributes=attributes or {},
            created_at=utc_now(),
        )
        db.add(event)
        db.commit()
        db.refresh(event)
        logger.debug(f"Added event to span {span_id}: {name}")
        return event.id

    # ==============================
    # Token Usage Tracking
    # ==============================

    def record_token_usage(
        self,
        db: Session,
        span_id: Optional[str] = None,
        trace_id: Optional[str] = None,
        model: Optional[str] = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
        total_tokens: Optional[int] = None,
        estimated_cost_usd: Optional[float] = None,
        provider: Optional[str] = None,
    ) -> None:
        """Record token usage for LLM calls.

        Args:
            db: Database session
            span_id: Span ID to attach token usage to
            trace_id: Trace ID (will use current context if not provided)
            model: Model name (e.g., "gpt-4", "claude-3-opus")
            input_tokens: Number of input/prompt tokens
            output_tokens: Number of output/completion tokens
            total_tokens: Total tokens (calculated if not provided)
            estimated_cost_usd: Estimated cost in USD
            provider: LLM provider (openai, anthropic, etc.)

        Examples:
            >>> service.record_token_usage(  # doctest: +SKIP
            ...     db, span_id="abc123",
            ...     model="gpt-4",
            ...     input_tokens=100,
            ...     output_tokens=50,
            ...     estimated_cost_usd=0.015
            ... )
        """
        if not trace_id:
            trace_id = current_trace_id.get()

        if not trace_id:
            logger.warning("Cannot record token usage: no active trace")
            return

        # Calculate total if not provided
        if total_tokens is None:
            total_tokens = input_tokens + output_tokens

        # Estimate cost if not provided and we have model info
        if estimated_cost_usd is None and model:
            estimated_cost_usd = self._estimate_token_cost(model, input_tokens, output_tokens)

        # Store in span attributes if span_id provided
        if span_id:
            span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
            if span:
                attrs = span.attributes or {}
                attrs.update(
                    {
                        "llm.model": model,
                        "llm.provider": provider,
                        "llm.input_tokens": input_tokens,
                        "llm.output_tokens": output_tokens,
                        "llm.total_tokens": total_tokens,
                        "llm.estimated_cost_usd": estimated_cost_usd,
                    }
                )
                span.attributes = attrs
                db.commit()

        # Also record as metrics for aggregation
        if input_tokens > 0:
            self.record_metric(
                db=db,
                name="llm.tokens.input",
                value=float(input_tokens),
                metric_type="counter",
                unit="tokens",
                trace_id=trace_id,
                attributes={"model": model, "provider": provider},
            )

        if output_tokens > 0:
            self.record_metric(
                db=db,
                name="llm.tokens.output",
                value=float(output_tokens),
                metric_type="counter",
                unit="tokens",
                trace_id=trace_id,
                attributes={"model": model, "provider": provider},
            )

        if estimated_cost_usd:
            self.record_metric(
                db=db,
                name="llm.cost",
                value=estimated_cost_usd,
                metric_type="counter",
                unit="usd",
                trace_id=trace_id,
                attributes={"model": model, "provider": provider},
            )

        logger.debug(f"Recorded token usage: {input_tokens} in, {output_tokens} out, ${estimated_cost_usd:.6f}")

    def _estimate_token_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """Estimate cost based on model and token counts.

        Pricing as of January 2025 (prices may change).

        Args:
            model: Model name
            input_tokens: Input token count
            output_tokens: Output token count

        Returns:
            Estimated cost in USD
        """
        # Pricing per 1M tokens (input, output)
        pricing = {
            # OpenAI
            "gpt-4": (30.0, 60.0),
            "gpt-4-turbo": (10.0, 30.0),
            "gpt-4o": (2.5, 10.0),
            "gpt-4o-mini": (0.15, 0.60),
            "gpt-3.5-turbo": (0.50, 1.50),
            # Anthropic
            "claude-3-opus": (15.0, 75.0),
            "claude-3-sonnet": (3.0, 15.0),
            "claude-3-haiku": (0.25, 1.25),
            "claude-3.5-sonnet": (3.0, 15.0),
            "claude-3.5-haiku": (0.80, 4.0),
            # Fallback for unknown models
            "default": (1.0, 3.0),
        }

        # Find matching pricing (case-insensitive, partial match)
        model_lower = model.lower()
        input_price, output_price = pricing.get("default")

        for model_key, prices in pricing.items():
            if model_key in model_lower:
                input_price, output_price = prices
                break

        # Calculate cost (pricing is per 1M tokens)
        input_cost = (input_tokens / 1_000_000) * input_price
        output_cost = (output_tokens / 1_000_000) * output_price

        return input_cost + output_cost

    # ==============================
    # Agent-to-Agent (A2A) Tracing
    # ==============================

    @contextmanager
    def trace_a2a_request(
        self,
        db: Session,
        agent_id: str,
        agent_name: Optional[str] = None,
        operation: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
    ):
        """Context manager for tracing Agent-to-Agent requests.

        This automatically creates a span for A2A communication, capturing timing,
        request/response data, and errors.

        Args:
            db: Database session
            agent_id: Target agent ID
            agent_name: Human-readable agent name
            operation: Operation being performed (e.g., "query", "execute", "status")
            request_data: Request payload (will be sanitized)

        Yields:
            Tuple of (span_id, result_dict) - update result_dict with A2A results

        Raises:
            Exception: Re-raises any exception from A2A call after logging

        Examples:
            >>> with service.trace_a2a_request(db, "agent-123", "WeatherAgent", "query") as (span_id, result):  # doctest: +SKIP
            ...     response = await http_client.post(...)  # doctest: +SKIP
            ...     result["status_code"] = response.status_code  # doctest: +SKIP
            ...     result["response_time_ms"] = 45.2  # doctest: +SKIP
        """
        trace_id = current_trace_id.get()
        if not trace_id:
            # No active trace, yield a no-op
            result_dict: Dict[str, Any] = {}
            yield (None, result_dict)
            return

        # Sanitize request data
        safe_data = {}
        if request_data:
            safe_data = {k: ("***REDACTED***" if any(sensitive in k.lower() for sensitive in ["password", "token", "key", "secret", "auth"]) else v) for k, v in request_data.items()}

        # Start A2A span
        span_id = self.start_span(
            db=db,
            trace_id=trace_id,
            name=f"a2a.call.{agent_name or agent_id}",
            kind="client",
            resource_type="agent",
            resource_name=agent_name or agent_id,
            attributes={
                "a2a.agent_id": agent_id,
                "a2a.agent_name": agent_name,
                "a2a.operation": operation,
                "a2a.request_data": safe_data,
            },
        )

        result_dict = {}
        try:
            yield (span_id, result_dict)

            # End span with results
            self.end_span(
                db=db,
                span_id=span_id,
                status="ok",
                attributes={
                    "a2a.result": result_dict,
                },
            )
        except Exception as e:
            # Log error in span
            self.end_span(db=db, span_id=span_id, status="error", status_message=str(e))

            self.add_event(
                db=db,
                span_id=span_id,
                name="a2a.error",
                severity="error",
                message=str(e),
                exception_type=type(e).__name__,
                exception_message=str(e),
                exception_stacktrace=traceback.format_exc(),
            )
            raise

    # ==============================
    # Transport Metrics
    # ==============================

    def record_transport_activity(
        self,
        db: Session,
        transport_type: str,
        operation: str,
        message_count: int = 1,
        bytes_sent: Optional[int] = None,
        bytes_received: Optional[int] = None,
        connection_id: Optional[str] = None,
        error: Optional[str] = None,
    ) -> None:
        """Record transport-specific activity metrics.

        Args:
            db: Database session
            transport_type: Transport type (sse, websocket, stdio, http)
            operation: Operation type (connect, disconnect, send, receive, error)
            message_count: Number of messages processed
            bytes_sent: Bytes sent (if applicable)
            bytes_received: Bytes received (if applicable)
            connection_id: Connection/session identifier
            error: Error message if operation failed

        Examples:
            >>> service.record_transport_activity(  # doctest: +SKIP
            ...     db, transport_type="sse",
            ...     operation="send",
            ...     message_count=1,
            ...     bytes_sent=1024
            ... )
        """
        trace_id = current_trace_id.get()

        # Record message count
        if message_count > 0:
            self.record_metric(
                db=db,
                name=f"transport.{transport_type}.messages",
                value=float(message_count),
                metric_type="counter",
                unit="messages",
                trace_id=trace_id,
                attributes={
                    "transport": transport_type,
                    "operation": operation,
                    "connection_id": connection_id,
                },
            )

        # Record bytes sent
        if bytes_sent:
            self.record_metric(
                db=db,
                name=f"transport.{transport_type}.bytes_sent",
                value=float(bytes_sent),
                metric_type="counter",
                unit="bytes",
                trace_id=trace_id,
                attributes={
                    "transport": transport_type,
                    "operation": operation,
                    "connection_id": connection_id,
                },
            )

        # Record bytes received
        if bytes_received:
            self.record_metric(
                db=db,
                name=f"transport.{transport_type}.bytes_received",
                value=float(bytes_received),
                metric_type="counter",
                unit="bytes",
                trace_id=trace_id,
                attributes={
                    "transport": transport_type,
                    "operation": operation,
                    "connection_id": connection_id,
                },
            )

        # Record errors
        if error:
            self.record_metric(
                db=db,
                name=f"transport.{transport_type}.errors",
                value=1.0,
                metric_type="counter",
                unit="errors",
                trace_id=trace_id,
                attributes={
                    "transport": transport_type,
                    "operation": operation,
                    "connection_id": connection_id,
                    "error": error,
                },
            )

        logger.debug(f"Recorded {transport_type} transport activity: {operation} ({message_count} messages)")

    # ==============================
    # Metric Management
    # ==============================

    def record_metric(
        self,
        db: Session,
        name: str,
        value: float,
        metric_type: str = "gauge",
        unit: Optional[str] = None,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        trace_id: Optional[str] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Record a metric.

        Args:
            db: Database session
            name: Metric name (e.g., "http.request.duration")
            value: Metric value
            metric_type: Metric type (counter, gauge, histogram)
            unit: Metric unit (ms, count, bytes, etc.)
            resource_type: Resource type
            resource_id: Resource ID
            trace_id: Associated trace ID
            attributes: Additional metric attributes/labels

        Returns:
            Metric ID

        Examples:
            >>> metric_id = service.record_metric(  # doctest: +SKIP
            ...     db,  # doctest: +SKIP
            ...     "http.request.duration",  # doctest: +SKIP
            ...     123.45,  # doctest: +SKIP
            ...     metric_type="histogram",  # doctest: +SKIP
            ...     unit="ms",  # doctest: +SKIP
            ...     trace_id=trace_id  # doctest: +SKIP
            ... )  # doctest: +SKIP
        """
        metric = ObservabilityMetric(
            name=name,
            value=value,
            metric_type=metric_type,
            timestamp=utc_now(),
            unit=unit,
            resource_type=resource_type,
            resource_id=resource_id,
            trace_id=trace_id,
            attributes=attributes or {},
            created_at=utc_now(),
        )
        db.add(metric)
        db.commit()
        db.refresh(metric)
        logger.debug(f"Recorded metric: {name} = {value} {unit or ''}")
        return metric.id

    # ==============================
    # Query Methods
    # ==============================

    # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals
    def query_traces(
        self,
        db: Session,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        min_duration_ms: Optional[float] = None,
        max_duration_ms: Optional[float] = None,
        status: Optional[str] = None,
        status_in: Optional[List[str]] = None,
        status_not_in: Optional[List[str]] = None,
        http_status_code: Optional[int] = None,
        http_status_code_in: Optional[List[int]] = None,
        http_method: Optional[str] = None,
        http_method_in: Optional[List[str]] = None,
        user_email: Optional[str] = None,
        user_email_in: Optional[List[str]] = None,
        attribute_filters: Optional[Dict[str, Any]] = None,
        attribute_filters_or: Optional[Dict[str, Any]] = None,
        attribute_search: Optional[str] = None,
        name_contains: Optional[str] = None,
        order_by: str = "start_time_desc",
        limit: int = 100,
        offset: int = 0,
    ) -> List[ObservabilityTrace]:
        """Query traces with advanced filters.

        Supports both simple filters (single value) and list filters (multiple values with OR logic).
        All top-level filters are combined with AND logic unless using _or suffix.

        Args:
            db: Database session
            start_time: Filter traces after this time
            end_time: Filter traces before this time
            min_duration_ms: Filter traces with duration >= this value (milliseconds)
            max_duration_ms: Filter traces with duration <= this value (milliseconds)
            status: Filter by single status (ok, error)
            status_in: Filter by multiple statuses (OR logic)
            status_not_in: Exclude these statuses (NOT logic)
            http_status_code: Filter by single HTTP status code
            http_status_code_in: Filter by multiple HTTP status codes (OR logic)
            http_method: Filter by single HTTP method (GET, POST, etc.)
            http_method_in: Filter by multiple HTTP methods (OR logic)
            user_email: Filter by single user email
            user_email_in: Filter by multiple user emails (OR logic)
            attribute_filters: JSON attribute filters (AND logic - all must match)
            attribute_filters_or: JSON attribute filters (OR logic - any must match)
            attribute_search: Free-text search within JSON attributes (partial match)
            name_contains: Filter traces where name contains this substring
            order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc)
            limit: Maximum results (1-1000)
            offset: Result offset

        Returns:
            List of traces

        Raises:
            ValueError: If invalid parameters are provided

        Examples:
            >>> # Find slow errors from multiple endpoints
            >>> traces = service.query_traces(  # doctest: +SKIP
            ...     db,
            ...     status="error",
            ...     min_duration_ms=100.0,
            ...     http_method_in=["POST", "PUT"],
            ...     attribute_filters={"http.route": "/api/tools"},
            ...     limit=50
            ... )
            >>> # Exclude health checks and find slow requests
            >>> traces = service.query_traces(  # doctest: +SKIP
            ...     db,
            ...     min_duration_ms=1000.0,
            ...     name_contains="api",
            ...     status_not_in=["ok"],
            ...     order_by="duration_desc"
            ... )
        """
        # Third-Party
        # pylint: disable=import-outside-toplevel
        from sqlalchemy import cast, or_, String

        # pylint: enable=import-outside-toplevel
        # Validate limit
        if limit < 1 or limit > 1000:
            raise ValueError("limit must be between 1 and 1000")

        # Validate order_by
        valid_orders = ["start_time_desc", "start_time_asc", "duration_desc", "duration_asc"]
        if order_by not in valid_orders:
            raise ValueError(f"order_by must be one of: {', '.join(valid_orders)}")

        query = db.query(ObservabilityTrace)

        # Time range filters
        if start_time:
            query = query.filter(ObservabilityTrace.start_time >= start_time)
        if end_time:
            query = query.filter(ObservabilityTrace.start_time <= end_time)

        # Duration filters
        if min_duration_ms is not None:
            query = query.filter(ObservabilityTrace.duration_ms >= min_duration_ms)
        if max_duration_ms is not None:
            query = query.filter(ObservabilityTrace.duration_ms <= max_duration_ms)

        # Status filters (with OR and NOT support)
        if status:
            query = query.filter(ObservabilityTrace.status == status)
        if status_in:
            query = query.filter(ObservabilityTrace.status.in_(status_in))
        if status_not_in:
            query = query.filter(~ObservabilityTrace.status.in_(status_not_in))

        # HTTP status code filters (with OR support)
        if http_status_code:
            query = query.filter(ObservabilityTrace.http_status_code == http_status_code)
        if http_status_code_in:
            query = query.filter(ObservabilityTrace.http_status_code.in_(http_status_code_in))

        # HTTP method filters (with OR support)
        if http_method:
            query = query.filter(ObservabilityTrace.http_method == http_method)
        if http_method_in:
            query = query.filter(ObservabilityTrace.http_method.in_(http_method_in))

        # User email filters (with OR support)
        if user_email:
            query = query.filter(ObservabilityTrace.user_email == user_email)
        if user_email_in:
            query = query.filter(ObservabilityTrace.user_email.in_(user_email_in))

        # Name substring filter
        if name_contains:
            query = query.filter(ObservabilityTrace.name.ilike(f"%{name_contains}%"))

        # Attribute-based filtering with AND logic (all filters must match)
        if attribute_filters:
            for key, value in attribute_filters.items():
                # Use JSON path access for filtering
                # Supports both SQLite (via json_extract) and PostgreSQL (via ->>)
                query = query.filter(ObservabilityTrace.attributes[key].astext == str(value))

        # Attribute-based filtering with OR logic (any filter must match)
        if attribute_filters_or:
            or_conditions = []
            for key, value in attribute_filters_or.items():
                or_conditions.append(ObservabilityTrace.attributes[key].astext == str(value))
            if or_conditions:
                query = query.filter(or_(*or_conditions))

        # Free-text search across all attribute values
        if attribute_search:
            # Cast JSON attributes to text and search for substring
            # Works with both SQLite and PostgreSQL
            # Escape special characters to prevent SQL injection
            safe_search = attribute_search.replace("%", "\\%").replace("_", "\\_")
            query = query.filter(cast(ObservabilityTrace.attributes, String).ilike(f"%{safe_search}%"))

        # Apply ordering
        if order_by == "start_time_desc":
            query = query.order_by(desc(ObservabilityTrace.start_time))
        elif order_by == "start_time_asc":
            query = query.order_by(ObservabilityTrace.start_time)
        elif order_by == "duration_desc":
            query = query.order_by(desc(ObservabilityTrace.duration_ms))
        elif order_by == "duration_asc":
            query = query.order_by(ObservabilityTrace.duration_ms)

        # Apply pagination
        query = query.limit(limit).offset(offset)

        return query.all()

    # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals
    def query_spans(
        self,
        db: Session,
        trace_id: Optional[str] = None,
        trace_id_in: Optional[List[str]] = None,
        resource_type: Optional[str] = None,
        resource_type_in: Optional[List[str]] = None,
        resource_name: Optional[str] = None,
        resource_name_in: Optional[List[str]] = None,
        name_contains: Optional[str] = None,
        kind: Optional[str] = None,
        kind_in: Optional[List[str]] = None,
        status: Optional[str] = None,
        status_in: Optional[List[str]] = None,
        status_not_in: Optional[List[str]] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        min_duration_ms: Optional[float] = None,
        max_duration_ms: Optional[float] = None,
        attribute_filters: Optional[Dict[str, Any]] = None,
        attribute_search: Optional[str] = None,
        order_by: str = "start_time_desc",
        limit: int = 100,
        offset: int = 0,
    ) -> List[ObservabilitySpan]:
        """Query spans with advanced filters.

        Supports filtering by trace, resource, kind, status, duration, and attributes.
        All top-level filters are combined with AND logic. List filters use OR logic.

        Args:
            db: Database session
            trace_id: Filter by single trace ID
            trace_id_in: Filter by multiple trace IDs (OR logic)
            resource_type: Filter by single resource type (tool, database, plugin, etc.)
            resource_type_in: Filter by multiple resource types (OR logic)
            resource_name: Filter by single resource name
            resource_name_in: Filter by multiple resource names (OR logic)
            name_contains: Filter spans where name contains this substring
            kind: Filter by span kind (client, server, internal)
            kind_in: Filter by multiple kinds (OR logic)
            status: Filter by single status (ok, error)
            status_in: Filter by multiple statuses (OR logic)
            status_not_in: Exclude these statuses (NOT logic)
            start_time: Filter spans after this time
            end_time: Filter spans before this time
            min_duration_ms: Filter spans with duration >= this value (milliseconds)
            max_duration_ms: Filter spans with duration <= this value (milliseconds)
            attribute_filters: JSON attribute filters (AND logic)
            attribute_search: Free-text search within JSON attributes
            order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc)
            limit: Maximum results (1-1000)
            offset: Result offset

        Returns:
            List of spans

        Raises:
            ValueError: If invalid parameters are provided

        Examples:
            >>> # Find slow database queries
            >>> spans = service.query_spans(  # doctest: +SKIP
            ...     db,
            ...     resource_type="database",
            ...     min_duration_ms=100.0,
            ...     order_by="duration_desc",
            ...     limit=50
            ... )
            >>> # Find tool invocation errors
            >>> spans = service.query_spans(  # doctest: +SKIP
            ...     db,
            ...     resource_type="tool",
            ...     status="error",
            ...     name_contains="invoke"
            ... )
        """
        # Third-Party
        # pylint: disable=import-outside-toplevel
        from sqlalchemy import cast, String

        # pylint: enable=import-outside-toplevel
        # Validate limit
        if limit < 1 or limit > 1000:
            raise ValueError("limit must be between 1 and 1000")

        # Validate order_by
        valid_orders = ["start_time_desc", "start_time_asc", "duration_desc", "duration_asc"]
        if order_by not in valid_orders:
            raise ValueError(f"order_by must be one of: {', '.join(valid_orders)}")

        query = db.query(ObservabilitySpan)

        # Trace ID filters (with OR support)
        if trace_id:
            query = query.filter(ObservabilitySpan.trace_id == trace_id)
        if trace_id_in:
            query = query.filter(ObservabilitySpan.trace_id.in_(trace_id_in))

        # Resource type filters (with OR support)
        if resource_type:
            query = query.filter(ObservabilitySpan.resource_type == resource_type)
        if resource_type_in:
            query = query.filter(ObservabilitySpan.resource_type.in_(resource_type_in))

        # Resource name filters (with OR support)
        if resource_name:
            query = query.filter(ObservabilitySpan.resource_name == resource_name)
        if resource_name_in:
            query = query.filter(ObservabilitySpan.resource_name.in_(resource_name_in))

        # Name substring filter
        if name_contains:
            query = query.filter(ObservabilitySpan.name.ilike(f"%{name_contains}%"))

        # Kind filters (with OR support)
        if kind:
            query = query.filter(ObservabilitySpan.kind == kind)
        if kind_in:
            query = query.filter(ObservabilitySpan.kind.in_(kind_in))

        # Status filters (with OR and NOT support)
        if status:
            query = query.filter(ObservabilitySpan.status == status)
        if status_in:
            query = query.filter(ObservabilitySpan.status.in_(status_in))
        if status_not_in:
            query = query.filter(~ObservabilitySpan.status.in_(status_not_in))

        # Time range filters
        if start_time:
            query = query.filter(ObservabilitySpan.start_time >= start_time)
        if end_time:
            query = query.filter(ObservabilitySpan.start_time <= end_time)

        # Duration filters
        if min_duration_ms is not None:
            query = query.filter(ObservabilitySpan.duration_ms >= min_duration_ms)
        if max_duration_ms is not None:
            query = query.filter(ObservabilitySpan.duration_ms <= max_duration_ms)

        # Attribute-based filtering with AND logic
        if attribute_filters:
            for key, value in attribute_filters.items():
                query = query.filter(ObservabilitySpan.attributes[key].astext == str(value))

        # Free-text search across all attribute values
        if attribute_search:
            safe_search = attribute_search.replace("%", "\\%").replace("_", "\\_")
            query = query.filter(cast(ObservabilitySpan.attributes, String).ilike(f"%{safe_search}%"))

        # Apply ordering
        if order_by == "start_time_desc":
            query = query.order_by(desc(ObservabilitySpan.start_time))
        elif order_by == "start_time_asc":
            query = query.order_by(ObservabilitySpan.start_time)
        elif order_by == "duration_desc":
            query = query.order_by(desc(ObservabilitySpan.duration_ms))
        elif order_by == "duration_asc":
            query = query.order_by(ObservabilitySpan.duration_ms)

        # Apply pagination
        query = query.limit(limit).offset(offset)

        return query.all()

    def get_trace_with_spans(self, db: Session, trace_id: str) -> Optional[ObservabilityTrace]:
        """Get a complete trace with all spans and events.

        Args:
            db: Database session
            trace_id: Trace ID

        Returns:
            Trace with spans and events loaded

        Examples:
            >>> trace = service.get_trace_with_spans(db, trace_id)  # doctest: +SKIP
            >>> if trace:  # doctest: +SKIP
            ...     for span in trace.spans:  # doctest: +SKIP
            ...         print(f"Span: {span.name}, Events: {len(span.events)}")  # doctest: +SKIP
        """
        return db.query(ObservabilityTrace).filter_by(trace_id=trace_id).options(joinedload(ObservabilityTrace.spans).joinedload(ObservabilitySpan.events)).first()

    def delete_old_traces(self, db: Session, before_time: datetime) -> int:
        """Delete traces older than a given time.

        Args:
            db: Database session
            before_time: Delete traces before this time

        Returns:
            Number of traces deleted

        Examples:
            >>> from datetime import timedelta  # doctest: +SKIP
            >>> cutoff = utc_now() - timedelta(days=30)  # doctest: +SKIP
            >>> deleted = service.delete_old_traces(db, cutoff)  # doctest: +SKIP
            >>> print(f"Deleted {deleted} old traces")  # doctest: +SKIP
        """
        deleted = db.query(ObservabilityTrace).filter(ObservabilityTrace.start_time < before_time).delete()
        db.commit()
        logger.info(f"Deleted {deleted} traces older than {before_time}")
        return deleted
